//
// Created by corey on 2021/4/20.
//
#include <iostream>
#include <NvOnnxParser.h>

#include "common.h"
#include "plugins/top_pool_plugin.h"
#include "plugins/bottom_pool_plugin.h"
#include "plugins/left_pool_plugin.h"
#include "plugins/right_pool_plugin.h"

#include "xtensor/xnpy.hpp"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xindex_view.hpp"
#include "xtensor/xrandom.hpp"
#include "xtensor/xadapt.hpp"
#include "xtensor/xsort.hpp"

#include "xtensor/xnpy.hpp"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xindex_view.hpp"
#include "xtensor/xrandom.hpp"
#include "xtensor/xadapt.hpp"
#include "xtensor/xsort.hpp"

#include "Utils.h"

using namespace nvinfer1;
using namespace std;

// Global console logger for the simplelogger utility (presumably declared in
// Utils.h — confirm); TRACE level emits all messages.
simplelogger::Logger *logger = simplelogger::LoggerFactory::CreateConsoleLogger(simplelogger::TRACE);

namespace
{
    // Logger instance passed to createInferRuntime() in main(); presumably
    // implements nvinfer1::ILogger — verify against common.h.
    hackathon::Logger gLogger;
}

int main(int argc, char *argv[])
{
    if (argc != 3)
    {
        cout << "missing plan file and input npy\n";
        return 0;
    }

    const char *planFile = argv[1];
    const char *npyFile = argv[2];
    BufferedFileReader reader(planFile);

    uint8_t *pBuf = nullptr;
    uint32_t nSize = 0;
    reader.GetBuffer(&pBuf, &nSize);
    if (!nSize)
    {
        return -1;
    }

    IRuntime *runtime = createInferRuntime(gLogger);
    ICudaEngine *engine = runtime->deserializeCudaEngine(pBuf, nSize);
    runtime->destroy();
    if (!engine)
    {
        cout << "No engine created\n";
        return -1;
    }

    auto context = engine->createExecutionContext();
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    std::vector<void *> dpBuf;
    std::vector<void *> hostBuf;;
    std::vector<pair<int, Dims>> hostDims;
    const int nBuf = engine->getNbBindings();
    for (int i = 0; i < nBuf; i++)
    {
        int size = 1;
        const Dims &dims = context->getBindingDimensions(i);
        for (int j = 0; j < dims.nbDims; j++)
        {
            size *= dims.d[j];
        }
        //printf("[%d]: %d,%d,%d,%d\n", i, dims.d[0], dims.d[1], dims.d[2], dims.d[3]);
        void *ptr;
        cudaMalloc(&ptr, size * sizeof(float));
        dpBuf.push_back(ptr);
        if (i > 0)
        {
            hostDims.push_back(make_pair(size, dims));
            hostBuf.push_back(malloc(size * sizeof(float)));
        }
    }

    xt::xarray<float> output_map = xt::load_npy<float>(npyFile);

    int loop = 1;
    while (loop--)
    {
        auto t0 = std::chrono::steady_clock::now();

        cudaMemcpyAsync(dpBuf[0], output_map.data(), output_map.size() * sizeof(float), cudaMemcpyHostToDevice, stream);

        context->enqueue(1, dpBuf.data(), stream, nullptr);

        cudaStreamSynchronize(stream);

        auto t1 = std::chrono::steady_clock::now();
        cout << "cost time: " << chrono::duration_cast<chrono::milliseconds>(t1 - t0).count() << " ms\n";
    }

    for (int i = 0; i < hostBuf.size(); i++)
    {
        cudaMemcpy(hostBuf[i], dpBuf[i + 1], hostDims[i].first * sizeof(float), cudaMemcpyDeviceToHost);
        auto dims = hostDims[i].second;
        std::string name = to_string(dims.d[0]) + "x" + to_string(dims.d[1]) + "x" + to_string(dims.d[2]) + "x" +
                           to_string(dims.d[3]);
        std::vector<std::size_t> shape = {static_cast<unsigned long>(dims.d[0]),
                                          static_cast<unsigned long>(dims.d[1]),
                                          static_cast<unsigned long>(dims.d[2]),
                                          static_cast<unsigned long>(dims.d[3])};
        auto result = xt::adapt((float *) hostBuf[i], shape);
        xt::dump_npy("result_" + to_string(i) + "_s" + name + ".npy", result);
    }


    for (auto &ptr : dpBuf) cudaFree(ptr);
    for (auto &ptr : hostBuf) free(ptr);

    if (stream) cudaStreamDestroy(stream);

    return 0;
}

